package nsalt.bus

import chisel3._
import chisel3.util._

import nsalt._ 

// More design details and principles can be found at
// https://oscpu.gitbook.io/nutshell/xi-tong-she-ji/bus

// This bus is explicitly designed for memory accessing.

/** Common (sealed) base of the bus bundles declared in this file.
  * Mixes in `Config`, which presumably supplies `DATA_BITS` used by the subclasses — confirm in nsalt.
  */
sealed abstract class BusPort extends Bundle with Config

/** Encodings driven onto [[BusMemReq]]: a 3-bit memory type (`MT_*`), a 1-bit
  * memory function (`M_*`), and plain integer port indices.
  */
object MemoryOpcode {
  // Memory access type codes, 3 bits wide.
  val MT_X  = 0.U(3.W)
  val MT_B  = 1.U(3.W)
  val MT_H  = 2.U(3.W)
  val MT_W  = 3.U(3.W)
  val MT_D  = 4.U(3.W)
  val MT_BU = 5.U(3.W)
  val MT_HU = 6.U(3.W)
  val MT_WU = 7.U(3.W)

  // Memory function codes, 1 bit wide. M_X ("don't care") shares M_XRD's encoding.
  val M_X   = 0.U(1.W)
  val M_XRD = 0.U(1.W) // int load
  val M_XWR = 1.U(1.W) // int store

  // Port indices (plain Scala Ints, not hardware).
  val DPORT = 0
  val IPORT = 1
}

/** Simple memory port: a decoupled request channel and a flipped decoupled response channel.
  *
  * @param size data width in bits of both the request and the response payload
  */
class BusMemPort(val size: Int) extends Bundle {
  val req = Decoupled(new BusMemReq(size))
  val res = Flipped(Decoupled(new BusMemRes(size)))
}

/** Request payload of [[BusMemPort]].
  *
  * @param size width in bits of `data`
  */
class BusMemReq(val size: Int) extends Bundle {
  val addr = Output(UInt(32.W)) // fixed 32-bit address, independent of `size` (original note: p(sodor_xprlen))
  val data = Output(UInt(size.W))
  val memFunc = Output(UInt(MemoryOpcode.M_X.getWidth.W))  // memory function code (MemoryOpcode.M_*)
  val memType = Output(UInt(MemoryOpcode.MT_X.getWidth.W)) // memory type (MemoryOpcode.MT_*)
}

/** Response payload of [[BusMemPort]].
  *
  * @param size width in bits of `data`
  */
class BusMemRes(val size: Int) extends Bundle {
  val data = Output(UInt(size.W))
}

/** Command encodings carried on the 4-bit `command` field of the bus.
  *
  * Every encoding is built with the full command width so that it matches
  * `BusCommand()` exactly, even in width-sensitive contexts such as `Cat`.
  * An unsized string literal would otherwise get its minimal width
  * (for example, `"b0000".U` is only 1 bit wide).
  *
  * Request and response encodings are separate namespaces, so one value may
  * appear in both: `PROBE` (request) and `PROBE_MISS` (response) are both
  * `b1000`.
  */
object BusCommand {
  /** Width in bits of every command encoding. */
  private val CommandBits = 4

  // req
                                                //   hit    |    miss
  def READ           = "b0000".U(CommandBits.W) //  read    |   refill
  def WRITE          = "b0001".U(CommandBits.W) //  write   |   refill
  def READ_BURST     = "b0010".U(CommandBits.W) //  read    |   refill
  def WRITE_BURST    = "b0011".U(CommandBits.W) //  write   |   refill
  def WRITE_LAST     = "b0111".U(CommandBits.W) //  write   |   refill
  def PROBE          = "b1000".U(CommandBits.W) //  read    | do nothing
  def PREFETCH       = "b0100".U(CommandBits.W) //  read    |   refill

  // resp
  def READ_LAST      = "b0110".U(CommandBits.W)
  def WRITE_RESP     = "b0101".U(CommandBits.W)
  def PROBE_HIT      = "b1100".U(CommandBits.W)
  def PROBE_MISS     = "b1000".U(CommandBits.W)

  /** Chisel type of a command field. */
  def apply() = UInt(CommandBits.W)
}

/** Request payload of the bus.
  *
  * @param userBits width of the optional `user` side-band field (omitted when 0)
  * @param addrBits width of `addr`
  * @param idBits   width of the optional `id` field (omitted when 0)
  */
class BusReqPort(val userBits: Int = 0, val addrBits: Int = 32, val idBits: Int = 0) extends BusPort {

  val addr = Output(UInt(addrBits.W))
  // NOTE(review): encoding of `size` is not defined in this file; presumably log2(bytes) — confirm with producers.
  val size = Output(UInt(3.W))

  val command = Output(BusCommand())

  // One write-mask bit per byte of dataW.
  val maskW = Output(UInt((DATA_BITS / 8).W))
  val dataW = Output(UInt(DATA_BITS.W))

  val user = if (userBits > 0) Some(Output(UInt(userBits.W))) else None
  val id =   if (idBits > 0)   Some(Output(UInt(idBits.W)))   else None

  /** Drives every field of this request in one call.
    *
    * `user` and `id` are assigned only when the corresponding optional field
    * exists; otherwise the argument is ignored.
    */
  def apply(addr: UInt, command: UInt, size: UInt, dataW: UInt, maskW: UInt, user: UInt = 0.U, id: UInt = 0.U) : Unit = {
    this.addr := addr
    this.command := command
    this.size := size
    this.dataW := dataW
    this.maskW := maskW
    // foreach, not map: these are pure side effects on optional fields.
    this.user.foreach(_ := user)
    this.id.foreach(_ := id)
  }

  // Bit 0 of the command marks a write and bit 1 marks a burst.
  // Bit 3 is set only in the probe encodings.
  // isRead therefore covers READ, READ_BURST and PREFETCH.
  def isRead()  = !command(0) && !command(3)
  def isWrite() = command(0)
  def isBurst() = command(1)

  def isReadBurst()   = command === BusCommand.READ_BURST
  def isWriteSingle() = command === BusCommand.WRITE
  def isWriteLast()   = command === BusCommand.WRITE_LAST
  def isProbe()       = command === BusCommand.PROBE
  def isPrefetch()    = command === BusCommand.PREFETCH
}


/** Response payload of the bus.
  *
  * @param userBits width of the optional `user` side-band field (omitted when 0)
  * @param idBits   width of the optional `id` field (omitted when 0)
  */
class BusResPort(val userBits: Int = 0, val idBits: Int = 0) extends BusPort {

  val command = Output(BusCommand())
  val dataR = Output(UInt(DATA_BITS.W))

  val user = if (userBits > 0) Some(Output(UInt(userBits.W))) else None
  val id =   if (idBits > 0)   Some(Output(UInt(idBits.W)))   else None

  /** Compares the response command against one encoding. */
  private def commandIs(expected: UInt): Bool = command === expected

  def isReadLast(): Bool  = commandIs(BusCommand.READ_LAST)
  def isProbeHit(): Bool  = commandIs(BusCommand.PROBE_HIT)
  def isProbeMiss(): Bool = commandIs(BusCommand.PROBE_MISS)
  def isWriteResp(): Bool = commandIs(BusCommand.WRITE_RESP)
  // Compared against the request-side PREFETCH encoding.
  def isPrefetch(): Bool  = commandIs(BusCommand.PREFETCH)
}

/** Bridges an uncached bus ([[BusUncached]]) to a simple memory port ([[BusMemPort]]).
  *
  * Valid/ready handshakes are forwarded one-to-one in both directions. Each
  * bus request becomes exactly one memory request, and each memory response
  * is reported back on the bus as `READ_LAST`.
  *
  * @param outType Chisel type of the memory-side port
  */
class BusMemConv(outType: BusMemPort) extends Module {

  val io = IO(new Bundle {
    val in = Flipped(new BusUncached)
    // Flipped(Flipped(x)) has x's own orientation. NOTE(review): presumably
    // written this way to obtain a fresh copy of outType; chiselTypeOf(outType)
    // would state that intent more directly.
    val out = Flipped(Flipped(outType))
  })

  io.in.req.ready := io.out.req.ready
  io.in.res.valid := io.out.res.valid
  io.out.req.valid := io.in.req.valid
  io.out.res.ready := io.in.res.ready

  io.out.req.bits.addr := io.in.req.bits.addr
  io.out.req.bits.data := io.in.req.bits.dataW

  // Only genuine write commands may write memory. Keying on isWrite() rather
  // than !isRead() keeps commands that are neither reads nor writes (e.g.
  // PROBE, b1000) from being turned into stores of whatever is on dataW.
  io.out.req.bits.memFunc := Mux(
    io.in.req.bits.isWrite(),
    MemoryOpcode.M_XWR,
    MemoryOpcode.M_XRD
  )

  // NOTE(review): neither req.bits.size nor req.bits.maskW is forwarded;
  // every access is issued as a full word (MT_W). Sub-word stores depend on
  // the memory side; confirm this is intended.
  io.out.req.bits.memType := MemoryOpcode.MT_W

  io.in.res.bits.dataR := io.out.res.bits.data
  // NOTE(review): write acknowledgements are also reported as READ_LAST,
  // not WRITE_RESP.
  io.in.res.bits.command := BusCommand.READ_LAST
}

/** Factory that instantiates a [[BusMemConv]], attaches `in`, and returns the memory-side port. */
object BusMemConv {
  def apply(in: BusUncached, outType: BusMemPort): BusMemPort = {
    val converterIo = Module(new BusMemConv(outType)).io
    converterIo.in <> in
    converterIo.out
  }
}


// Cache
/** Bus bundle for a cached client: a memory channel plus a second, flipped channel.
  *
  * @param userBits width of the optional `user` field on both channels
  */
class BusCached(val userBits: Int = 0) extends BusPort {
  // Same orientation as a plain BusUncached.
  val mem = new BusUncached(userBits)
  // Opposite orientation to `mem`. Presumably carries coherence (probe) traffic,
  // per the name — confirm.
  val coh = Flipped(new BusUncached(userBits))
}

// Uncached
/** Uncached bus: a decoupled request channel and a flipped decoupled response channel.
  *
  * @param userBits width of the optional `user` field on requests and responses
  * @param addrBits request address width
  * @param idBits   width of the optional `id` field on requests and responses
  */
class BusUncached(val userBits: Int = 0, val addrBits: Int = 32, val idBits: Int = 0) extends BusPort {

  val req = Decoupled(new BusReqPort(userBits, addrBits, idBits))

  val res = Flipped(Decoupled(new BusResPort(userBits, idBits)))

  /** Qualifies a request-side predicate with `req.valid`. */
  private def validAnd(cond: Bool): Bool = req.valid && cond

  /** A valid write request is being presented on `req`. */
  def isWrite(): Bool = validAnd(req.bits.isWrite())
  /** A valid read request is being presented on `req`. */
  def isRead(): Bool  = validAnd(req.bits.isRead())

  // Disabled: these refer to a SimpleBus2AXI4Converter that is not defined here.
  // def toAXI4Lite() = SimpleBus2AXI4Converter(this, new AXI4Lite, false)
  // def toAXI4(isFromCache: Boolean = false) = SimpleBus2AXI4Converter(this, new AXI4, isFromCache)

  /** Converts this bus to a 32-bit [[BusMemPort]] through a [[BusMemConv]] bridge. */
  def toMemPort(): BusMemPort = {
    val memPortType = new BusMemPort(32)
    BusMemConv(this, memPortType)
  }

}

/** N-to-1 crossbar that merges `n` uncached bus masters onto one uncached bus.
  *
  * A LockingArbiter chooses among the input requests and keeps its grant on a
  * master while that master issues write-burst beats. At most one transaction
  * is outstanding:
  *  - After a read, or a single or final write, is accepted, no new request
  *    is forwarded until the matching response has been handed back.
  *  - The response is routed to the master recorded in `inflightSrc`.
  *
  * @param n        number of input masters
  * @param userBits width of the optional `user` field on every port
  */
class BusCrossbarFrom(n: Int, userBits:Int = 0) extends Module {

  val io = IO(new Bundle {
    val in = Flipped(Vec(n, new BusUncached(userBits)))
    val out = new BusUncached(userBits)
  })

  val s_idle :: s_read_resp :: s_write_resp :: Nil = Enum(3)
  val state = RegInit(s_idle)

  // Hold the grant on the current master while it is issuing write-burst beats.
  val lockWrite = (x: BusReqPort) => x.isWrite() && x.isBurst()

  // NOTE(review): 8 is the LockingArbiter `count` (beats per locked sequence);
  // presumably the maximum burst length — confirm.
  val inputArbiter = Module(new LockingArbiter(chiselTypeOf(io.in(0).req.bits), n, 8, Some(lockWrite)))

  inputArbiter.io.in.zip(io.in).foreach { case (arbIn, master) => arbIn <> master.req }

  val reqArbitrated = inputArbiter.io.out

  // Only reads and writes are handled. Any other command (e.g. PROBE) never
  // leaves s_idle, so a later request could overwrite inflightSrc and
  // misroute its response.
  assert(!(reqArbitrated.valid && !reqArbitrated.bits.isRead() && !reqArbitrated.bits.isWrite()),
    "BusCrossbarFrom: request is neither a read nor a write")

  // Index of the master whose transaction is currently outstanding.
  val inflightSrc = RegInit(0.U(log2Up(n).W))

  io.out.req.bits := reqArbitrated.bits

  // Output request valid/ready are bound to the arbiter, gated by the idle state.
  io.out.req.valid := reqArbitrated.valid && (state === s_idle)
  reqArbitrated.ready := io.out.req.ready && (state === s_idle)

  // Every master sees the response payload; by default none sees it as valid.
  io.in.foreach { master =>
    master.res.bits := io.out.res.bits
    master.res.valid := false.B
  }

  // The master recorded in inflightSrc gets the response handshake
  // (last connect wins over the default above).
  val inflightRes = io.in(inflightSrc).res
  inflightRes.valid := io.out.res.valid
  io.out.res.ready := inflightRes.ready

  switch (state) {
    is (s_idle) {
      when (reqArbitrated.fire) {
        inflightSrc := inputArbiter.io.chosen
        when (reqArbitrated.bits.isRead()) {
          state := s_read_resp
        }
        .elsewhen (reqArbitrated.bits.isWriteLast() || reqArbitrated.bits.isWriteSingle()) {
          // Non-final write-burst beats stay in s_idle; only the last beat waits for a response.
          state := s_write_resp
        }
      }
    }
    is (s_read_resp) {
      // A read may return several beats; finish on the one marked READ_LAST.
      when (io.out.res.fire && io.out.res.bits.isReadLast()) {
        state := s_idle
      }
    }
    is (s_write_resp) {
      when (io.out.res.fire) {
        state := s_idle
      }
    }
  }
}

